import os
import pickle
import pprint
import warnings
from pathlib import Path

import pandas as pd
import torch
import tqdm
from torch_geometric.data import Data
from torch_geometric.data import Dataset, InMemoryDataset

from fragment.protein_fragments import constants
from fragment.protein_fragments.process_proteins import CustomData

class LBADataset(InMemoryDataset):
    """In-memory PyG dataset built from individually saved graph files.

    Every regular file directly inside ``raw_dir`` is read with
    ``torch.load`` and is expected to hold one sample object (presumably a
    ``Data``/``CustomData`` instance -- TODO confirm against the code that
    writes these files). The loaded samples are optionally filtered and
    pre-transformed, then collated and cached at ``processed_paths[0]``
    (``data.pt``). On later constructions the base class finds that cache
    and ``process`` is not re-run.

    Args:
        root: Dataset root directory; PyG derives ``raw_dir`` and
            ``processed_dir`` from it.
        transform: Optional callable applied to each sample on access.
        pre_transform: Optional callable applied once to each sample
            before it is cached.
        pre_filter: Optional predicate; samples for which it returns a
            falsy value are dropped before caching.
    """

    def __init__(self, root,
                 transform=None, pre_transform=None, pre_filter=None):
        # The base constructor runs ``process`` when the cache is missing.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    @property
    def raw_file_names(self):
        """Return the names of all regular files directly in ``raw_dir``.

        Sorted, because ``Path.iterdir`` yields entries in an arbitrary,
        filesystem-dependent order, which would otherwise make the sample
        order of the processed dataset nondeterministic.
        """
        return sorted(
            entry.name for entry in Path(self.raw_dir).iterdir()
            if entry.is_file()
        )

    def process(self):
        """Load every raw file, apply pre_filter/pre_transform, and cache.

        Files that cannot be loaded are skipped (best effort), and a single
        summary warning reports how many were skipped and why.

        Returns:
            The list of samples that was saved. PyG ignores this value; it
            is kept for backward compatibility.

        Raises:
            RuntimeError: If no samples remain to be saved, either because
                every file failed to load or because ``pre_filter`` rejected
                them all.
        """
        # ``self.raw_paths`` re-lists the directory on every access; read it once.
        raw_paths = self.raw_paths
        data_list = []
        failures = []
        for file_path in tqdm.tqdm(raw_paths):
            try:
                # torch>=2.6 defaults to weights_only=True, which rejects
                # pickled Python objects such as graph data. Loading full
                # objects executes pickle, so raw files must be trusted.
                data = torch.load(file_path, weights_only=False)
            except Exception as err:  # best effort: record and move on
                failures.append((file_path, err))
                continue
            data_list.append(data)

        if failures:
            examples = "; ".join(f"{path}: {err!r}" for path, err in failures[:5])
            warnings.warn(
                f"Skipped {len(failures)} of {len(raw_paths)} raw files that "
                f"could not be loaded (e.g. {examples})"
            )

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        if not data_list:
            raise RuntimeError(
                f"No samples to save for dataset at {self.root!r}: "
                f"{len(raw_paths)} raw file(s) found, {len(failures)} failed "
                f"to load, remaining samples were removed by pre_filter"
            )

        self.save(data_list, self.processed_paths[0])
        return data_list


if __name__ == "__main__":
    # Smoke test: construct the dataset for each split (processing raw files
    # on first use) and print the resulting split-name -> dataset mapping.
    splits = {}
    for split_name in ('training', 'val', 'test'):
        split_root = constants.PROCESSED_DATA_DIR / split_name
        splits[split_name] = LBADataset(split_root)
    pprint.pp(splits)
